{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        # 隐层大小\n",
    "        self.hidden_size = hidden_size\n",
    "        # 词表大小\n",
    "        self.vocab_size = vocab_size\n",
    "        # 词嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
    "            self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: batch * seq_len\n",
    "        # 注意门控循环单元使用batch_first=True，因此输入需要至少batch为1\n",
    "        features = self.embedding(inputs)\n",
    "        output, hidden = self.gru(features)\n",
    "        return output, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # 序列到序列任务并不限制编码器和解码器输入同一种语言，\n",
    "        # 因此解码器也需要定义一个嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将输出的隐状态映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1,\\\n",
    "            dtype=torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden = self.forward_step(\\\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 为了与AttnRNNDecoder接口保持统一，最后输出None\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器的输入和隐状态。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用最广泛的注意力机制，特别是机器翻译任务。该注意力机制使用一个对齐模型（alignment model）来计算编码器和解码器隐状态之间的注意力分数，具体来讲就是一个前馈神经网络。相比于点乘注意力，Bahdanau注意力利用了非线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        # query: batch * 1 * hidden_size\n",
    "        # keys: batch * seq_length * hidden_size\n",
    "        # 这一步用到了广播（broadcast）机制\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = torch.bmm(weights, keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # 输入来自解码器输入和上下文向量，因此输入大小为2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将注意力的结果映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
    "            torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden, attn_weights = \\\n",
    "                self.forward_step(\n",
    "                    decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        # 与RNNDecoder接口保持统一，最后输出注意力权重\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embeded =  self.embedding(input)\n",
    "        # 输出的隐状态为1 * batch * hidden_size，\n",
    "        # 注意力的输入需要batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embeded, context), dim=2)\n",
    "        # 输入的隐状态需要1 * batch * hidden_size\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 直接使用TransformerLayer作为编码层，简单起见只使用一层\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 与TransformerLM不同，编码器不需要线性层用于输出\n",
    "        \n",
    "    def forward(self, input_ids):\n",
    "        # 这里实现的forward()函数一次只能处理一句话，\n",
    "        # 如果想要支持批次运算，需要根据输入序列的长度返回隐状态\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "        \n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e8cfc9c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        \"\"\"\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        queries = self.transpose_qkv(self.W_q(tgt))\n",
    "        keys = self.transpose_qkv(self.W_k(src))\n",
    "        values = self.transpose_qkv(self.W_v(src))\n",
    "        # 这一步与自注意力不同，计算交叉掩码\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
    "        # 重复张量的元素，用以支持多个注意力头的运算\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
    "            repeats=self.num_heads, dim=0)\n",
    "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
    "        # (hidden_size / num_heads)\n",
    "        output = self.attention(queries, keys, values, attention_mask)\n",
    "        # batch * tgt_seq_len * hidden_size\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "# TransformerDecoderLayer比TransformerLayer多了交叉多头注意力\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        # 掩码多头自注意力\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # 交叉多头自注意力\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # 前馈神经网络\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 简单起见只使用一层\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 解码器与TransformerLM一样，需要输出层\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        # 确保一次只输入一句话，形状为1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "        \n",
    "        if tgt_tensor is not None:\n",
    "            # 确保一次只输入一句话，形状为1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "        \n",
    "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
    "            fill_(SOS_token)\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\\\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            \n",
    "            if tgt_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    tgt_tensor[:, i:i+1]), 1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    topi.squeeze(-1).detach()), 1)\n",
    "                \n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 与RNNDecoder接口保持统一\n",
    "        return decoder_outputs, None, None\n",
    "        \n",
    "    # 解码一步，与RNNDecoder接口略有不同，RNNDecoder一次输入\n",
    "    # 一个隐状态和一个词，输出一个分布、一个隐状态\n",
    "    # TransformerDecoder不需要输入隐状态，\n",
    "    # 输入整个目标端历史输入序列，输出一个分布，不输出隐状态\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
    "            tgt_mask)\n",
    "        output = self.output_layer(hidden_states[:, -1:, :])\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
    "        self.n_words = 2  # Count SOS and EOS\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "            \n",
    "    def sent2ids(self, sent):\n",
    "        return [self.word2index[word] for word in sent.split(' ')]\n",
    "    \n",
    "    def ids2sent(self, ids):\n",
    "        return ' '.join([self.index2word[idx] for idx in ids])\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# 文件使用unicode编码，我们将unicode转为ASCII，转为小写，并修改标点\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # 在标点前插入空格\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取文件，一共有两个文件，两个文件的同一行对应一对源语言和目标语言句子\n",
    "def readLangs(lang1, lang2):\n",
    "    # 读取文件，分句\n",
    "    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    print(len(lines1), len(lines2))\n",
    "    \n",
    "    # 规范化\n",
    "    lines1 = [normalizeString(s) for s in lines1]\n",
    "    lines2 = [normalizeString(s) for s in lines2]\n",
    "    if lang1 == 'zh':\n",
    "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
    "    if lang2 == 'zh':\n",
    "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['财 务 精 英 这 样 用 e x c e l', \"that's how the financial elite uses excel .\"]\n"
     ]
    }
   ],
   "source": [
    "# 为了快速训练，过滤掉一些过长的句子\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(pairs)} 对序列\")\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引）。在两个句子的末尾会添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    for idx, (src_sent, tgt_sent) in enumerate(pairs):\n",
    "        src_ids = input_lang.sent2ids(src_sent)\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
    "        # 添加<eos>\n",
    "        src_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "        \n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b85d1ba",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.0394: 100%|█| 20/20 [17:18<00:00, 51.95s/it\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAASKZJREFUeJzt3Qd8U+X+x/FfFx3QwZ6FspENMmSIIIgiclFxo+IWxIEoCk7UvxbHRRyIXvWC14U4QEWGgAyRIXvvWfbsonTn/3oeTEjapEnbpCfj8369DlknJ89pUvLtM4NMJpNJAAAAvFCw0QUAAABwhKACAAC8FkEFAAB4LYIKAADwWgQVAADgtQgqAADAaxFUAACA1woVH5afny9HjhyR6OhoCQoKMro4AADABWoKt7S0NKlVq5YEBwf7b1BRISU+Pt7oYgAAgBJISkqSOnXq+G9QUTUp5hONiYkxujgAAMAFqampuqLB/D3ut0HF3NyjQgpBBQAA3+JKtw060wIAAK9FUAEAAF6LoAIAALwWQQUAAHgtggoAAPBaBBUAAOC1CCoAAMBrEVQAAIDXIqgAAACvRVABAABei6ACAAC8FkEFAAB4LZ9elNBTMnPy5My5bFFrJdWMjTS6OAAABCxqVOyYufGodB33h4z+cZPRRQEAIKARVOyIjrhQ0ZSamWN0UQAACGiGBpW8vDx58cUXpX79+hIZGSkNGzaU1157TUwmk1cElbTMXEPLAQBAoDO0j8qbb74pkyZNki+++EJatGghq1evlnvvvVdiY2Pl8ccfN6xc0eFh+jKdoAIAQOAGlWXLlsnAgQOlf//++nZCQoJ8++238vfff9vdPysrS29mqampHilXZLkQfZmRTVABACBgm366du0qCxYskJ07d+rbGzZskKVLl0q/fv3s7p+YmKhrW8xbfHy8R4NKZk6+R44PAAB8oEZl9OjRulakWbNmEhISovusvP766zJ48GC7+48ZM0ZGjhxpua2e64mwEhV2Iahk5+VLbl6+hIbQ5xgAgIALKtOmTZOvv/5avvnmG91HZf369TJixAipVauWDBkypND+4eHhevM0c42KkpGTJzEEFQAAAi+ojBo1Steq3Hbbbfp2q1at5MCBA7qJx15QKSvhoReDSZZq/okwrCgAAAQ0Q6sKMjIyJDjYtgiqCSg/39i+IUFBQRIcdOG60UOlAQAIZIbWqAwYMED3Salbt65u+lm3bp2MHz9e7rvvPjFacFCQ5JtMkk9OAQAgMIPKBx98oCd8e+SRR+TEiRO6b8rDDz8sL730knhDUBExSR41KgAABGZQiY6OlgkTJujN2+gWqTyRfKpUAAAwDMNZiqxRUX1UjC4JAACBi6DiJKiofioAAMAYBBUH/skpBBUAAAxEUHGAGhUAAIxHUHEg5J+JVOhLCwCAcQgqDpgnfKNGBQAA4xBUipidVjF4klwAAAIaQcUBalQAADAeQcUB5lEBAMB4BBUnQYUp9AEAMA5BxQHzos40/QAAYByCitOmH4IKAABGIag4nfDN6JIAABC4CCrOptAnqQAAYBiCigMhdKYFAMBwBBUHGJ4MAIDxCCoOsHoyAADGI6g4QGdaAACMR1BxgHlUAAAwHkHFSWda5lEBAMA4BBUnqyfnsXoyAACGIag4wOrJAAAYj6DiAFPoAwBgPIKKA4z6AQDAeAQVBxj1AwCA8QgqTmpU8qhSAQDAMAQVB5hCHwAA4xFUHGAKfQAAjEdQcYDOtAAAGI+g4kDIPxOpUKMCAECABpWEhAQ9A2zBbfjw4eItE74xjwoAAMYJNfC1ZdWqVZKXl2e5vXnzZrnqqqvk5ptvFqMxhT4AAAEeVKpWrWpze9y4cdKwYUO54oorxGhMoQ8AQIAHFWvZ2dny1VdfyciRIy21GQVlZWXpzSw1NdVj5WEKfQAAjOc1nWlnzJghycnJcs899zjcJzExUWJjYy1bfHy8x8oTbOlM67GX
AAAAvhJUPv/8c+nXr5/UqlXL4T5jxoyRlJQUy5aUlFQGw5NJKgAABHTTz4EDB2T+/Pny008/FblfeHi43sqyjwpT6AMAEOA1KpMnT5Zq1apJ//79xVswhT4AAMYzPKjk5+froDJkyBAJDfWKCh6NKfQBADCe4UFFNfkcPHhQ7rvvPvEmIUyhDwCA4Qyvwujbt69XDgGmMy0AAMYzvEbFWwX/85PxxhAFAECgIKg4wBT6AAAYj6DiAFPoAwBgPIKKk860NP0AAGAcgoqTph9G/QAAYByCigOM+gEAwHgEFad9VIwuCQAAgYug4nT1ZJIKAABGIag4m0KfKhUAAAxDUHGAKfQBADAeQcVJZ9q8fGZ8AwDAKAQVBzYfSdGXXyw/YHRRAAAIWAQVBxbtOGl0EQAACHgEFQAA4LUIKgAAwGsRVFyQlZtndBEAAAhIBBUXMPAHAABjEFQc6Nu8uuV6LkkFAABDEFQc6JhQyXKdnAIAgDEIKk6m0FeoUQEAwBgEFQfKhV780ZzPyZOcPMIKAABljaDiwKD2dSzXu7+5UBo/P1u+XMEstQAAlCWCigPlw0ML3ffijM2GlAUAgEBFUCmBE6mZsv1YqtHFAADA7xFUimnZ7lPS6Y0Fcs2EP+Xg6QyjiwMAgF8jqBTTHZ+ttFzfevTCCssAAMAzCCqlkJnDSCAAADyJoFIK87YdN7oIAAD4NYJKKYQFW80KBwAA3I6gUgphIfz4AADwJL5pXZz0zZ5QggoAAB5l+Dft4cOH5c4775TKlStLZGSktGrVSlavXi3eINRJ005YCE0/AAB4UuHpV8vQ2bNnpVu3btKrVy+ZPXu2VK1aVXbt2iUVK1YUb3AyPavIx0ODDc95AAD4NUODyptvvinx8fEyefJky33169d3uH9WVpbezFJTPTs7bHpmbpGPh4VSowIAgCcZWiXwyy+/SIcOHeTmm2+WatWqSbt27eTTTz91uH9iYqLExsZaNhVyPOlYamaRj+89ec7m9q7jaTL5r32Sncv8KgAA+HxQ2bt3r0yaNEkaN24sc+fOlWHDhsnjjz8uX3zxhd39x4wZIykpKZYtKSnJo+WrYGdhQmvzttrOo3LVu0vklV+3yvh5Oz1aLgAAAoWhTT/5+fm6RuWNN97Qt1WNyubNm+Xjjz+WIUOGFNo/PDxcb2WlQoTzH4/JZJKgINsmoI8X75HR/Zp5sGQAAAQGQ2tUatasKc2bN7e575JLLpGDBw+KN3AlbKzcd0bmbD4qKRk5ZVImAAACiaE1KmrEz44dO2zu27lzp9SrV0+8Qfu6FaVB1fKF+qJYu+0/K/Rlm/i4MiwZAACBwdAalSeffFJWrFihm352794t33zzjfznP/+R4cOHG1ksG7OfuNyl/TYkJXu8LAAABBpDg0rHjh1l+vTp8u2330rLli3ltddekwkTJsjgwYPFW4SHhkjzmjFGFwMAgIBkaNOPct111+nNm1UsH2Z0EQAACEhMreqCOnFRRhcBAICARFBxAUONAQAwBkHFBRXLlyvW/tVjym6uFwAA/BlBxQNYrBAAAPfgG9VFU+7t6PK+17Ss4dGyAAAQKAgqLmpey/UhyuGh/FgBAHAHvlFdVC06QipGuTZM+VhKpvyw5pDk5LGKMgAApUFQKYbel1R3ab+f1h2Wp7/fIP9ZstfjZQIAwJ8RVIph8+GUYu3/x/YTHisLAACBgKBSDNuPpRVr/1yafgAAKBWCigdl5RJUAAAoDYKKB+WbTEYXAQAAn0ZQ8aCjyZlyIi3T6GIAAOCzCCoelJaVK51eXyB5+dSsAABQEgSVYmhWI7pEz2M+FQAASoagUgwD29Yu8XObvzRHEkb/Jt+sPOjWMgEA4M8IKsUQFFSy563ef1YysvP09eemb5L1ScnuLRgAAH6KoFIMJR3Ec+fnK21urzlw1j0FAgDAzxFUiiHOxbV+nHlt5la3HAcAAH9HUCmGQe3ryL/a1DK6GAAABAyCSjGUCw2W929v
J6HBJeysAgAAioWgUoadagEAQPEQVEqAmfEBACgbBJUSyCOpAABQJggqJUBOAQCgbBBUAACA1yKoAAAAr0VQAQAAXougAgAAvBZBpQS+fqCzVC5fzuhiAADg9wwNKmPHjpWgoCCbrVmzZuLtujWqIqtf6FOqY5gYOgQAgFOhYrAWLVrI/PnzLbdDQw0vkktUqCoNlVOY4RYAgKIZngpUMKlRo4YEmnyTSYKFpAIAgFf3Udm1a5fUqlVLGjRoIIMHD5aDBw863DcrK0tSU1NtNl9Fww8AAF4eVDp37ixTpkyROXPmyKRJk2Tfvn1y+eWXS1pamt39ExMTJTY21rLFx8eLN2haPbpENSoAAKBoQSYv6tWZnJws9erVk/Hjx8v9999vt0ZFbWaqRkWFlZSUFImJiSnj0oqsPXhWfll/RO7vXl8uf2thsZ67/bVrJCIsxGNlAwDAW6nvb1Xh4Mr3t+F9VKzFxcVJkyZNZPfu3XYfDw8P15u3aF+3ot6UZaOvlAnzd8q01Ydceq73xEMAALyX4X1UrKWnp8uePXukZs2a4mtqxUVKz6bVXN7fRC8VAAC8O6g8/fTTsnjxYtm/f78sW7ZMbrjhBgkJCZHbb79dfFG/ljXk7Ztay/BeDZ3um09OAQDAKUObfg4dOqRDyenTp6Vq1arSvXt3WbFihb7ui9TcKjd3iJdfNhxxui+daQEA8PKgMnXqVPFHy/ectrk9ul8zGTd7u8195BQAAHysj4q/SM/Ktbk99IqGUqnA2kBeNNgKAACvRVDxgNBg5zPOklMAAHCOoOIBwVaL+FzeuIrb+qj8vuVYoWYly/HonQsA8EMEFQ/XqNSMjbC7T0Z2XrGOeTTlvDz05Rq5/dMVhR5buOOEtBw7V351oRMvAAC+hKDiAcF2mn4K3jNh/i59uedkulw1frH8vP5wkcc8mXZxRt6C7p28Sgefx75dV8ISAwDgnQgqBvVRUbUgyjM/bJRdJ9LlianrXT4+HXEBAIGCoOIBIVZB5a7LEuzuc+Zctr48V2CEkCPW2YScAgAIFAQVDweVGv/0UbHqX1tokrjiyiOpAAACBEHFA6xbfszXT6VfqEEpKetoonLKgdPnZOLC3ZKWmVOq4wIA4M28avVkf9GkerTLNSZBJeiXciwlU/q996fuQLv/1LkSlxMAAG9HUPGAjgmVLNdDQxxHkdy8/BIdv8fbCy3XV+yzP68KAAD+gKYfD7DuQ+JsBND2Y6mW62oyN0fhxVGvlPySZR0AAHwCQcUD8qxmiQ0NdvwjVntZTyirJnP7cOFu+/s6SCoMVQYA+DOCigfku1ijYm8a/U+X7JU5m49KRnauDiGvzdwqX688UMQx3FBgAAC8FH1UPKB6dESRs9Sa2asMOZedJ0O/WivXta4pQ7omyOdL9+n7fxzWxW1rBgEA4CsIKh5QsXw5+eXRbhIZFlLkfkWFjJkbj8oN7Wo7fS11BDWwiLwCAPBHBBUPaV0nzuk+5vV+HLEe2VxUE4/ajZwCAPBH9FEx0H+W7C3y8SCrWVasO+haUzUp1nO1bDqU4sYSAgBgLIJKGQkPLf6Peu6WY5bravFCVwz4cGmxXwcAAG9FUCkj17aqWeznTF2VZLl+8EyG3X1OpWc5rG0BAMDXEVTKyEvXNZebLq1jdDEAAPApBJUyHAn0zs1tpGNCRaOLAgCAzyCoAAAA/woqX3zxhfz222+W288884zExcVJ165d5cABx7OoAgAAeDyovPHGGxIZGamvL1++XCZOnChvvfWWVKlSRZ588smSHDJg5NLxFQAAz074lpSUJI0aNdLXZ8yYIYMGDZKHHnpIunXrJj179izJIQNGjoPVkQEAgJtqVCpUqCCnT5/W13///Xe56qqr9PWIiAg5f/58SQ4ZMF66roXRRQAAwL9rVFQweeCBB6Rdu3ayc+dOufbaa/X9W7ZskYSEBHeX0a90ql/J6CIAAODfNSqqT0qXLl3k5MmT8uOPP0rlypX1
/WvWrJHbb7/d3WUEAAABKshk8t11d1NTUyU2NlZSUlIkJiZGfMUDX6yW+duOe+z4+8f199ixAQAoy+/vEtWozJkzR5YuXWpTw9K2bVu544475OzZsyU5ZEC5tlUNo4sAAIBPKFFQGTVqlE5DyqZNm+Spp57S/VT27dsnI0eOLFFBxo0bp1cBHjFihPg7q8WOAQCAuzvTqkDSvHlzfV31Ubnuuuv03Cpr1661dKwtjlWrVsknn3wirVu3lkAQTFIBAMBzNSrlypWTjIwLq/nOnz9f+vbtq69XqlTJUtPiqvT0dBk8eLB8+umnUrFiYKyD06ZOnNFFAADAf4NK9+7ddRPPa6+9Jn///bf073+h86YaqlynTvFWCB4+fLh+fp8+fZzum5WVpYOQ9eaLqkSHG10EAAD8N6h8+OGHEhoaKj/88INMmjRJateure+fPXu2XHPNNS4fZ+rUqbq5KDEx0aX91X6ql7B5i4+PL0nxAQCAP/dRqVu3rsycObPQ/e+++26xpuF/4oknZN68eXpGW1eMGTPGprOuqlHxxbBSLsRzi1YH0/0FABDoQUXJy8vT6/xs27ZN327RooX861//kpCQEJeeryaHO3HihLRv397mmEuWLNE1NqqZp+CxwsPD9ebryoUGy9N9m8g7v+90+7F9dlIcAADcFVR2796tR/ccPnxYmjZtammWUbUbv/32mzRs2NDpMXr37q2HNlu79957pVmzZvLss8+6HHh81aNXNvZMUCGpAAACPag8/vjjOoysWLFCj/RR1CKFd955p35MhRVnoqOjpWXLljb3lS9fXk/HX/B+AAAQmEoUVBYvXmwTUhQVMNSkbd26dXNn+VBMEWGe6/8CAIBPBBXVTyQtLc3unChqjpWSWrRoUYmfiwvyafoBAPiREv35rWaifeihh2TlypWi1jRUm6phGTp0qO5QC9fERJS4L7ND+SQVAECgB5X3339f91Hp0qWLHlqstq5du0qjRo1kwoQJ7i+ln6pTMcrtx8ynNy0AwI+U6E/6uLg4+fnnn/XoH/Pw5EsuuUQHFbiuc4NKsvWoe2fXdUeFiqohW3PgrDSpES0xEWHuKBYAAJ4NKs5WRV64cKHl+vjx40tWmgAz6uqmUiMmQhJnb3e6b6f6lSQtM1eublFdJszf5TRoqJWoS+qXDUfkianrJTYyTBaP6ilxUSXvdwQAQJkElXXr1rm0X2m+IANNVLlQefiKhrLuYLLM2XKsyH2bVK8g/3d9K1l38KzToKJqVUJK8Tb8uuGIvkw5nyNtX50nO/+vn56kDgAArw0q1jUmcK9n+zVzGlSC/wmAf2w/4VI/lRApeVLJybNtPzqbkS3VY1xb5gAAAHfiz2QvUL9KeXn4igYu7etKn5G8UnZUKfh8+ucCAIxCUPESFcoVXbllDgtxUc6DSmmDRW5+fukOAACAmxBUvJi9lZAzc/M9PkR5z8lzNrdNLHUIADAIQcVL2IsCd3dJsHr8wh7hIcElCio7jqXJmJ82ybGUzCKf+/P6w3IyLavA8Zy+JAAAHuH+qVHhNmF2hu5ULO98qLC9YNHvvSX6/j0n0mXa0C56CHNWbr5EhNmuUq2GJRc6HkkFAGAQalS8hL3WmmrRhUfaOKpQeWtQ6yKDhfmubf9MMPfg/9ZIsxfnyJHk807LFmKvDQoAgDJAUPFit3SMLxRkGlWNtrvvzR3quNRHJS0rV1/O33ZcX36/+pDTclCfAgAwCk0/XqpqdLieGbagupWj5KmrmkhkuRAZP2+nZGTnWSbaU1OtqIxiXaGy8VCyJM6ynfn2y+X7i9VRlqYfAIBRCCpeaMNLfSWinOPKrsd6N9aXKqgUnBQu75/VrM1u+ni5ZBcYKfTiz1ss1zNznI8iYh4VAIBRaPrxEpUqXOwkGxsVJuGhtp1cXQkQ5q4k1hUgBUOKo+nyi8KKzAAAo1Cj4iVu7RAv6w8mS48mVWzurx0XKYeTz0v/VjULPadgs82FdZZMxQoW
OXnOa1RULQ0AAEYgqHgJtejfv29pU+j+35/sIYfOnpemNex3orVXozJz4xF5qEdDl17X5EJfFOumJAAAyhJNP16ufHiow5BSMD+Yb78xa7skncmQWz9Z7vT4CZWjLNfv/Hyl3X1+XHtY+r//pz4mAABliaDiR6ybfC5/a6Gs3HfG6XPWHUyWK95eKFuOpMiyPaft7jNp0R7ZciRVXp251a3lBQDAGYKKDytYo5KTV/wmmtx8kxw4nSEPf7nG6b4nCkytDwCApxFUfJg7O7mqfjDObEhKdtvrAQDgCoKKD2PYMADA3xFUfBg5BQDg7wgqPiw6gtHlAAD/RlDxYV/e31ma1YiWr+7vbHRRAADwCP4k92Ft4+NkzogeRhcDAACPoUYFLrvrsnpGFwEAEGAIKnBZiHmOfgAAyghBBQAAeC1Dg8qkSZOkdevWEhMTo7cuXbrI7NmzjSyST+vdrJrRRQAAwH+CSp06dWTcuHGyZs0aWb16tVx55ZUycOBA2bJli5HFAgAAXsLQUT8DBgywuf3666/rWpYVK1ZIixYtDCuXr1ILBwIA4E+8ZnhyXl6efP/993Lu3DndBGRPVlaW3sxSU/litnYsNdOjxz9zLtujxwcAwOs6027atEkqVKgg4eHhMnToUJk+fbo0b97c7r6JiYkSGxtr2eLj48u8vIHsfE6e0UUAAAQYw4NK06ZNZf369bJy5UoZNmyYDBkyRLZu3Wp33zFjxkhKSoplS0pKKvPyBrLwUMM/LgCAAGN400+5cuWkUaNG+vqll14qq1atkvfee08++eSTQvuqWhe1wRi14yKNLgIAIMB43Z/I+fn5Nv1Q4D0TubFYMwAgoGpUVFNOv379pG7dupKWlibffPONLFq0SObOnWtksQJOXj4RBADgnQwNKidOnJC7775bjh49qjvHqsnfVEi56qqrjCwWHAhlCn0AQCAFlc8//9zIl/c7zWvGyNajnhuy7c6Kl53H0+STxXtlRJ/GEl8pyn0HBgD4FcM708J9gkpY4dGgannZe/Kc0/1Mbuyl0vfdJfpy8+EUmftkD7cdFwDgX7yuMy1KLriEScXV0TwmD3Rl2XE8zf0HBQD4DYKKHylpjUpC5fIu7ZdPp1sAQBkjqPgRezll8j0d5c7L6hb5vEGX1nHp+OQUAEBZI6j4kSA7VSpJZzMkNLjotznExaoYd/ZRAQDAFQQVP2Jv9HBYSLDERYUV+bysXMdr+Hw+pIM83KNBifqomEwmeX/BLlm882TxnggAwD8IKn7kjs71Ct3Xq2k1pwEjLqqcw8d6X1LdssZPfjGTypzNx2T8vJ0y5L9/F+t5AACYEVT8yKD2tWXmY93lx2FdLPdFhAXLda1r6utNq0fbfV7dSlHSs2lVp01KruSU/afOyeu/bZXjqZny/h+7i38SAABYIaj4ERUoWtaOlXIhITbr+DSuHi1/P9dbfn2su93nlQsN1p1uHR9XXK5RGTRpmXz65z5di7LNg5PPAQACA0HFD1kHCnNH2moxETqQqBqXWzvEywv9L3HaEddcy2Ken8WVUT+nz2Xry+3H7M+PkpuXLwdPZxTndAAAAYyZaf08qBRcGVnVuLx5U2s9J0pOnkk6JFR0eJxK5S/0Xbl4CNuk8sniPTJ/23H54r5OElXOtY/Sk9M2yK8bjtjcdyItU6pFR7j0fABAYKFGxQ9Z13wUDCpmwcFBMqxnQ+mYUMnxcf45kLm2JT/f9vHE2dtl1f6z8vWKg5ZRPo488vUa/XjBkKKs3HvGyRkBAAIVQcUPWQcGR0HFnpgI21qRu7ok6Etzq5AaZvzoN2vlzD/NO2ZnM7Llnsl/S/0xsxwee9amY4WeZzb6x40ulxEAEFho+vFDJZ1BtkeTqjJz41GpFRshvzzWXapUCLfpo3IsNVM/HhEWIu/c3MbyvI8W7XHp+G/M2m73/nPZjudxAQAENoKKH6pXOapEz3v9hlbSNj5O+reuaQkpSsE6mSPJ50t0/B/X
HirR8wAAgYug4oeqx0TIL492k+iIomekLSg2MkweuPzCLLRFrcqsWpZYoBAAUBYIKn6qdZ04tx2r4Mjl5XtPS4fX57vt+AAAOEJnWjhlb44VRx1jAQBwJ4IKnCrGwCEAANyKoAKnCvZRAQCgrBBU4BQ5BQBgFIIKStRHBQCAskBQgVP0UQEAGIWgAqfoowIAMApBBU4RUwAARiGowClqVAAARiGowClyCgDAKAQVAADgtQgqcIrlBwEARiGoAAAAr0VQgWFVKg2rlvfMgQEAfsPQoJKYmCgdO3aU6OhoqVatmlx//fWyY8cOI4sEO0weSCp/PtNLLm9c1e3HBQD4F0ODyuLFi2X48OGyYsUKmTdvnuTk5Ejfvn3l3LlzRhYLBZjcmFMGta8jU+7tKPGVouS+bvWtXoOeMACAwkLFQHPmzLG5PWXKFF2zsmbNGunRo4dh5YItd0aITvUrSs+m1fT16IiLH798k0gIw6ABAN4UVApKSUnRl5UqVbL7eFZWlt7MUlNTy6xsgSwjO89tx6oVF2m5Hmy1iFBevklCWFQIAOCtnWnz8/NlxIgR0q1bN2nZsqXDPi2xsbGWLT4+vszLGYjmbj7mtmM1rhZtuW4dTPJp+gEAeHNQUX1VNm/eLFOnTnW4z5gxY3Sti3lLSkoq0zIGqr4tqhf5+Ds3t3F6jITKUTLr8culRmyE5b6QINsaldI6nHxe3pu/S06nX6x1AwD4Nq8IKo8++qjMnDlTFi5cKHXq1HG4X3h4uMTExNhs8LyKUeWKfPymSy++Z7GRYXb3qVMxSprXsn2/gq0+fXlF1Khk5+bLmXPZNvedz86Td+bukC+X75fdJ9L0fYM/XSHvzt8pT0xdX/QJAQB8hqF9VNRIj8cee0ymT58uixYtkvr1L44CgXeu9dMmPk4On82QU+m2wcGsZe0Y+Wv36SKPYa9GJb+IGpUBHyyVHcfT9JBmNVpImb7usHy4cLdln/3j+sv+0xn6+tLdp1w8MwCAtws2urnnq6++km+++UbPpXLs2DG9nT9/3shioYiQMX1YV1k+prfltgoP1tYfTJaZj3V36bi2fVQc76dCijJ3y8W+Mn8VCCMJo39z6TUBAL7F0BqVSZMm6cuePXva3D958mS55557DCoVCgqSIJuROsESpGsw7OnWqIq0rB1b+Bh2qlSCitlHxXr/7Lx8l8oOAPBthjf9wPvZa7Zx5NpWNe3eX75ciMNaFRVS1Kgf1fxjPWR58+EUue6DpZbb5oemrUqSeVuPF1mOgscCAPgmr+hMC+8WGWY/ZNhzab2Khe6rUzFSXryuud39zf1Uhn21Rq789yLJzLk4Z4t1SFHMseOZHzc6Lcfrs7a5XGYAgPfyqgnf4J3UmjxNq0fLpQmFQ4jZhpf6SvL5bEtnV2tLn73S4fP0yJ88kbUHk/XtP3edkquaV5eVe+11yHW9huTzpfschiMAgO8gqMCpyHIhMvfJopc0iI0K01tB93cveiSX9cgfZfWBMzqo3PqfFYX2VS05787b6XK5txxJkRa1CveXAQD4Dpp+4FE5Tjq9FuxHMnPDUYf7qhqV9xbscvm1+79v23QEAPA9BBV4RI8mVfXlbR3rFrlfwfV91Oyy7V+bZ3ffF2ZsdmMJAQC+gKYfeMSUezpKWmau3eagopp+lIKz0AIAAhc1KvAI1aTjLKQUd+gzACDwEFRgKEdT8QMAoBBUAACA1yKowK+lZeYYXQQAQCkQVODXJi7c45bjzN50VH5ef9gtxwIAuI5RP/Brp9KzSn2MjOxcGfb1WsssvZXKl3NDyQAArqBGBX7tXFZuqY8xd8sxtx4PAOA6ggr8Wqob+qjk5F5c5XvPyXTZfSKt1McEALiGoAK/8u+b29jc/mt34cUNi8skF4PKPZNXSZ/xS3RzEADA8wgq8Elv3dTa7v3Na8W4/bVOpBbu55J6nqACAGWBoAKfUzU6XG7pEC/3dkso9NglNWOke6MqHu+QW3CNIgCA
ZxBU4HNM/7TEvDyghewf11+aVK9g8/hXD3R26+vVqRhV6L6JC3e79TUAAPYRVOBzCq4PNP2Rbvryhna1LffVjovUlw90r1+sY69PSpaPF++RvHxTkc1JU5btL26xAQAlwDwq8GqvDWwhL/68xea+go0u5cNDdc2KtQFtaunAUVzXT/zLsqrzgz0a2NTgAADKHjUq8Gp3dUmQxBtb2dyXULm80+eF/tOHJNeqZqQ4Xp+1TQ9FvnCM/BIdAwBQegQVGOq/93Rwus/tnerKhpf7yk+PdJX+rWvKu7e1dfocc2dX6yYcs+zcfPlyxQHZd+qcw32U+6as0pe5eVSpAIBRaPqBoa5sVt2l/WIjw6R93YrS/o6KLu1vrlHJys0r9NikRXvk3fk79fW3b2otz0/fLJ8N6SA9mlS12e/A6Qx9+eS09S69JgDA/ahRgV/6fetxfTlt9aFCj5lDijLqh42SnZcvD3252u5xVG1LWiZzpgCAUQgq8Fpj+jUr8XM3HU4p1v7BBYcS/WPJrpMlLgMAoPRo+oFXua51TTmemilv3NBKGlePLrPXzcjOs/RZsXbv5Av9VAAAxiCowHC1YiPkSEqmvv7hHe3dcsz4SpGSdOZ8sZ7T651FbnltAID70PQDw333cBfdkfUbN84oO6h9Hcv1hNG/yZPfrZe0zBzZeZyVjwHAl1CjAsPFV4qS/93Xya3HPJ9tO9pn+rrDenOn/HyTBLPmDwB4FDUq8Evt6ro2jLk08pmyFgD8O6gsWbJEBgwYILVq1ZKgoCCZMWOGkcWBHwkL8XxNRx5BBQD8O6icO3dO2rRpIxMnTjSyGPBDYSGe/2iTUwDAz/uo9OvXT2+Au4WWQY0KTT8A4Hk+1Zk2KytLb2apqamGlgfe61R6tsdfw9EaQSWRm5cvD/5vtbSqHSsj+zZ123EBwNf5VGfaxMREiY2NtWzx8fFGFwleKjnD80GltDlFjRoy/VMr8+euU7Jwx0l5/4/d7ikcAPgJnwoqY8aMkZSUFMuWlJRkdJEQ0H1UTKUKKTd89JfUHzNLth1NlfSsi+sJrdx72k0lBADf51NBJTw8XGJiYmw2wJ4+l7i2KrMnm36SzmTIkp0n5dr3/pR1B8/aPPbGrG2y4dCF9Yj6vfenzbGen7HZQyUGAN/jU31UAFdFhHk+gxeVU9YcOCODJi233L790xWy/bWLHcc/W7rPZn/roJKTl+/uogKAzzK0RiU9PV3Wr1+vN2Xfvn36+sGDB40sFvxAdERYqZ4/8qomDh8L+Wc22qJG/fyw5pDN7cycfJmx7rD8vP6wHE0pvAbRjPWHPdJJFwB8naE1KqtXr5ZevXpZbo8cOVJfDhkyRKZMmWJgyRAIdr/eT9Iyc6Xda/Ms93374GVyab2KUi40WMbP22m5/+M728vIaRvk/dvaybCv10iek6Bi76ER310I5PaozrRmh84WbzFFAPBnhgaVnj17lqpDIlAaoSHBUrF8Ocvt/q1rSpeGlS23b+8UL9/+naTDi7q/b/Maem0fNYuyiKnIpp+cPD7XABBwnWmB0hrSpZ7DxxpVrWBzO/HG1rIv8VpLeDEvQGheh1CN3LF29ly2JXiv2n/G3UUHgIBEZ1oEjDZ1YuWVgS2lZe1YqR4TUehx1dxT0IXaEynU30TJyM7TwUTts2zPKbnj05X6/v3j+svBMxkeOQcACDQEFQQEFR7Mbu5gf6LAUHNViYuunrBEL364bHRvS0hRzmerHizeISs3T66fuExa146VN29qbXRxAKDYaPoBCozmKQ7VF6Xj6/Nt7luy66R4i4XbT+oJ5b5bzeSIAHwTNSrwe7c6qEEpKC7qYsfaotSvUl72nTrn8PGHv1wjRpu4cLfUiImQE2kX18ZSfWrM/WwAwFcQVOC3Jt/TUWZvPipj/9WiyP3G3dhKlu89LQPb1nLpuCetvvw9oW/z4s+qezj5vNw8aZnc1SVBGlerIG/P3VFonx/XHnLY7AUA3oqgAr/V
q1k1vTlzW6e6enOV9bo8nqAGEy3bfUouTago4aEhLj1n/O875UhKprw5Z7tUshpybW3lvjMEFQA+hz4qgEFWv9BHd/Id3a+Zzf3ztx2XOz5bKS8WY80f6/mIzpy7uHJ0k+oXh1zP2Xys1GUGgLJGUAEMEhZ84dfvocsb2H182mrbafgdyczJk5/WXZyC31q/ljXLrCYIADyBoAIYJCTknwnkStnB9ZPFex0+VtQ0/wDgCwgqQDE92quRW45T3HlbrM3adFR6vbNIUjNz9PBjRz74Y7fNbRY8BOBr6EwLFNOHC22//MsyqKgJ3D5fuk/emnNhVE/n1xfI+RzXJ5ibt/W4XNOyRrFfFwCMQo0K4CZqhWVH5oy4XKY93EU6JlS0O8HcR4PtP3fz4RSZtjrJUhPy+m/bLCFFKU5IUapGuzZXDAB4C2pUADdRqyvXqxwlB07brvMz/pY20qxGjL7+/dCu8tac7VIzLtJmHaFrW9WUz4d0kPu/WG3z3Os+WKovn/lho16LKDv3wjpDJRXyTwded1ATyHV/8w9pXitWPhvSwW3HBQBrBBXATVSn2MWjesm5rFxp8fJcfV+zGtFyY/s6Nvs9c43tcGSz0JCiQ0RpQ4q7jmGmpuVXc7eoLS0zR6Ijwtx2bAAwo+kHKKbtr11T6L5vHuxsuV4+PFTmj+whw3o2lKkPXebyccPKYHr7kgYVFURy8i4+t+fbC2XMT5sst0dMXV+qcqlwd/PHy+SzPx2PYAIQmAgqQDFFhIXoidpGXd3Ucl/XhlVs9mlULVqevaaZy+sHKWWxDo912Cg4imjktPWSdMa22UpRI4tajf1dGj8/Ww6dzdBhYn+B5q0F20+UqlxfLN8vq/aflf/7bVupjgPA/9D0A5RQ+7oXO8a6gzubZQpqXzdO1h5Mliw7r5FyPkce+Xqtvv7T2sMSFxUm61/qq28fST5vs8hi9zcXurVcarK6HcfSdBkAwB6CClBCXRpWlv/e00EaVLk4TX1pVKkQLu5yeeMqevHE7cfSpG6lKMuaQdl2alRUk4u15IwLoeGxb9fJrxuOiCc99OUaWbLzpEdfA4BvI6gApXBls+KvdOyI9XDlkujeqIos3X1KrxqtFmNUQeXbvw/KLR3i5ZkfN+p9cnLzdTOPeqV+rS5Mr7/zeHqhY83ZfLRYIUUFo5KwF1JU81SYk47FAAIHQQXwEiYp3ayxk+/tKEeTM6Vu5Sh9u2p0uDzeu7G+fjwlU18+9f0Gm+dMuLWt3WMN/epCU5Cr3DlV/7jZ2+XF65q77XgAfBt/tgBeomn1aJf3VZUvy8dcaTMSSdVCmENKQTuOp9m9f8R3pRutY5aVk1+s+VeKombeTRj9m9z40V/y83r7iy0CCBxBJuv14X1MamqqxMbGSkpKisTEXJhQC/BlG5KSZe6WY/LRoj0O97mna4KM/VeLYh1XffGXhX2J1+rLuVuOS/OaMTo4bTmSIjERYXrY9od/7Jb//rVPXvlXC/l+TZJsPpzq0jGtJ8crqf8u3SexkWEy6FLbeW0AePf3N00/gBdpEx+nt+CgILetKaRUiw6XE2lZbjveNS1qyFs3t5ZNh1Jk8GcrLfcv3HFCdhxLlzfnbNe3F4/qKf3fvzC7rrWXf9ni8mudSs/WzViu2HfqnKxPOisDWtfSE+ipUUW7T6TrgPLqzK16H3cHFVVDtGLfaakdFyn1Kpd367EBEFQAr/T01U0dBpWHr2hQ7OO9Oai13DtlVanLVbB2IzzUtvX4vim2SwB8tNBxzZCrMrJz1SsVul/N+aLC16X1LgwT/2ntIRk57UIfnPTMXLmrS4I0e3GOvt2pfiXL81SImrRoj/w8vJsOhSX1ztwdOphtOWJbK3R921oy4bZ2JT4uAFv0UQG81NcPdC40dLlFrRipGRtZoqHURVn6bC/59dHuNvetfqGPJFSOkgFtakm7unGycWzfQk0w5mHPRU2zX1pXvL1Ini7QCVi5/K2FMmjSMpmz
+Zg888MGS0hRXvx5i01fmL/3nbFcVyFFGTjxL4evqVrEZ286KusOntXXH/hitQz8cKnk5uVL4uxtuilNBcmCIUWZsd59Q7qPpWTKg/9bLQcLTLDn7dYnJcuXKw7oGi2gtKhRAbxUt0ZV9KrL8RWjdP+OoynnSxRSzLPpvnFDK3lu+sVp763VqRgldSqKPNmnibw7f6clHC0a1avI46qFEsvCD2sOyXPXXiKVypfTwUF9EZoN/erihHTWGjw3y+lxp/y1T+7ukmAzK/Abs7bJf5bstRlNNX/bcX290fOzXSqvKmNp+tWoCfCiyoXIZYkL9O15W4/LU1c1kUevbOT0uM5eW/UZigwLkQZV3TP/T0GqluneyRdq7/adPCcvDWhe4mUV1CSIFcuz4negozMtECDUr3r9MbZf3g2qlpfpw7pJbNTFBQU7vT5fejSpKu/c3MalPiG93lnk1nKqzLA3sb80e3G2ZBYYTfTawBa6tsSdbu9UVxJvbOXWjscj+jSWJ3o3tgQGtVaS6s+jarasQ4Sq9UnLytVNaCpMKpsPp1hWzbZHLd/g8Fz+s0KOpWbK3BE97IbI/afOSU8775c7Oizf8NFfsu5gslQID5X0rFyXyluU5i/NkYzsPF2TpzpjW9cyVYgI1a/jLmoG5q7j/pDZT1wul9Qs/neJen9fnLFZbro0XrqXcE6hQJNKZ1oABakvIvWFdDQlUw9lrlKhnN0vp7+f7+PyMavHuGc23bCQIJn1+OVyLjtPLql5YZh2wZCiuDukKGpSvOf7X6InwftokXs6ME+Yv0vXiDzUo6EcTj4v3cb9YXls66tXS7mQ4EK1M3+NvlJ3yC0qpCjLdp/S4ebqFjUKTZS3fO9pfX3joWQ94d+zP27U4UiNwlL9dKybwKypAGsvrKggVdQaVGfPZUtMZJikns/RIUWxDinKruNp0rjA0Ht13Dav/i6d61eSz4Z0tBuqVUhRWo/9XS/wqdbOUv2SVJOfcvOldeRtF8K0K1RIUfq996fLwUrV9qjavDZ14mRd0llZtOOkbvazfr6zn19BqmlRdaR39TlrD56VSlHlJKGKf3fipkYFQKlMW50kKRk58vos5wsKFvwSUH/hV65QTqKt/mIu6yHVnqTmt7nr85V6wUVPUP2KTqRlyv1f2HZiLgnVT0k1AZqdV6HxpTl61uGXB7TQIWrOlqPy7cok+Xu//cDj7GexYNsJGf7NWr2g59tzd+j71UrjahFP6/fcugnSbNnoKy2Bwuztm1rLzR3ii9Uxe9hXa6V8eIi8dVMbS61Mwc+aWhrD2azT09cdkie/K9x3asFTV0jvfy+2ue/3J3tIw6oV5MDpc/LktA3yv3s7WWox1ai0PuMXy8C2tXStmwp+0x/pqtflUjVtQQXCo/rK/mv3acnNz5d7/mli2/vGtYXCTfc3/5CasRHy/dCuLjWz5ZlM8ubs7fL1yoPy2JWN5Km+FxddNfr7m6ACwC0+WbxHEmdv11P4H0k5Ly//vEV+ebS7hIcF646p93dvIJHliu58a001QZV0SHWfS6pLrbgI+d/yAxeOlVBJrmhaVYb3auSxEPTmoFby7I/2+wD5ik1j+8rqA2dlz4l0+WblQdl76px4u7+f661rX1Stwm8bj+og9Nag1jJj/WEZ1L6OnkBw61Hn8/U4C9Xqy1w1p6nao/M5eTJ/63EZ++uFIe8lYT5+m1d+L3JRzo8Gt7csGqoCzA0fLSvymCpgrtx32hJiCjadqQ7OqkZ12Z5T8umf++Sdm1tLp9cv9IWyVis2Qo78M6O1u+Yy8umgMnHiRHn77bfl2LFj0qZNG/nggw+kU6dOTp9HUAG8izvX6VHV5puPpEjLWrF6in7rppJmNaL1gouK+kv05/VHZO2LV+nOttae/WGjNK0RLfd1r29zf2mCynu3tZWuDatIx9fnF+pX0yVxgW5ac6f4SpHSu1l1mbJsv1uPi8Dx1+grbZofi6tGTISseK534AaV7777
Tu6++275+OOPpXPnzjJhwgT5/vvvZceOHVKtWrUin0tQAQJLVm6eZUi06q+gqsnVZG7FparY/9h+olDzgqKq2+/uUk+GXtFQKlcI130uIkKD9aKR1n9Vqj4Kqm9NwfuavODayCBFdVi2N/RaefDy+jpgmUd6XffBny7N5OuM6iOj+q+YJ6frmrjA8pezr7BuOkLZKGmnaL8IKiqcdOzYUT788EN9Oz8/X+Lj4+Wxxx6T0aNHF/lcggqA0lDV+S1enisP9Wggd11WTwchc3+JklLHOJ2eLWsOnNU1PQ2rltcdQa37EKjq+dCQIF37pP4LVh1uVR8QVWvyyq9bZeHTPaW+nQ6Sp9OzdB+HoVc00H0e1PBl9T+4Clc1YiPkwD/zragmnF0n0qV93YqWvhnJGTlSK67w8HZzXxRXqOYANarlxvZ1JC/fZDMrsbL+pat0H4cb2tXWr6VqnVQoKsptHeNl6qrC8+20jY+TWzvGS8eEivp1jqdmyf3d68sLquNzUJBP9GH67fHuejkM1STly5aPubLEUyP4fFDJzs6WqKgo+eGHH+T666+33D9kyBBJTk6Wn3/+2Wb/rKwsvVmfqAo1BBUAuDB0V43EKm5/AhXYVLh6b8EuHQ5a14nVAag4fYrsUV8vqkbKXmdpR1StlCq+syZENZpIzaej9rNeXFM1zalOslc2q6b72KgaNzUnkApkyeeznX7hFhzGrzrCzt50TL5bdVDOZGRLp/qVJTIsWLo0qCzdG1fVzY0Vo8L0z9w8ysfRaJ+CQ8/VyCU1aaDqP1U9Olxf79WsmsxYd1ivfK6GfP/xVE8d+lSfEjX/TVxUmDSoUkFumLRMbmxXW4Z0TdDHuuyNBXpouhptZh4xpeYAMs9pY6YmjVRzEqnwpz4rHRMq6fdcfXb+/ftOmTG8mxw6myGNq0XbTFvgbj4TVI4cOSK1a9eWZcuWSZcuXSz3P/PMM7J48WJZudI2rY8dO1ZeeeWVQschqAAA4DuKE1R8agr9MWPG6JMyb0lJpZ+eGwAAeC9DJ3yrUqWKhISEyPHjF6anNlO3a9SwncxICQ8P1xsAAAgMhtaolCtXTi699FJZsODiGG7VmVbdtm4KAgAAgcnwKfRHjhypO8926NBBz52ihiefO3dO7r33XqOLBgAAAj2o3HrrrXLy5El56aWX9IRvbdu2lTlz5kj16kVPXwwAAPyf4fOolAbzqAAA4Hv8dtQPAAAILAQVAADgtQgqAADAaxFUAACA1yKoAAAAr0VQAQAAXougAgAAvBZBBQAAeC3DZ6YtDfNcdWriGAAA4BvM39uuzDnr00ElLS1NX8bHxxtdFAAAUILvcTVDrd9Ooa9WWj5y5IhER0dLUFCQ29OeCkBJSUl+OT0/5+f7/P0cOT/f5u/nFwjnmOrB81PRQ4WUWrVqSXBwsP/WqKiTq1OnjkdfQ705/vgBNOP8fJ+/nyPn59v8/fwC4RxjPHR+zmpSzOhMCwAAvBZBBQAAeC2CigPh4eHy8ssv60t/xPn5Pn8/R87Pt/n7+QXCOYZ7yfn5dGdaAADg36hRAQAAXougAgAAvBZBBQAAeC2CCgAA8FoEFTsmTpwoCQkJEhERIZ07d5a///5bvNHYsWP1jLzWW7NmzSyPZ2ZmyvDhw6Vy5cpSoUIFGTRokBw/ftzmGAcPHpT+/ftLVFSUVKtWTUaNGiW5ubk2+yxatEjat2+ve343atRIpkyZ4pHzWbJkiQwYMEDPVKjOZcaMGTaPq37fL730ktSsWVMiIyOlT58+smvXLpt9zpw5I4MHD9aTE8XFxcn9998v6enpNvts3LhRLr/8cv3+qlkX33rrrUJl+f777/XPUu3TqlUrmTVrlsfP75577in0fl5zzTU+c36JiYnSsWNHPVO0+ixdf/31smPHDpt9yvIz6e7fY1fOr2fPnoXew6FDh/rE+SmTJk2S1q1bWyb46tKli8yePdsv3j9Xzs/X37+Cxo0bp89hxIgRvv0e
qlE/uGjq1KmmcuXKmf773/+atmzZYnrwwQdNcXFxpuPHj5u8zcsvv2xq0aKF6ejRo5bt5MmTlseHDh1qio+PNy1YsMC0evVq02WXXWbq2rWr5fHc3FxTy5YtTX369DGtW7fONGvWLFOVKlVMY8aMseyzd+9eU1RUlGnkyJGmrVu3mj744ANTSEiIac6cOW4/H/X6zz//vOmnn35SI9FM06dPt3l83LhxptjYWNOMGTNMGzZsMP3rX/8y1a9f33T+/HnLPtdcc42pTZs2phUrVpj+/PNPU6NGjUy333675fGUlBRT9erVTYMHDzZt3rzZ9O2335oiIyNNn3zyiWWfv/76S5/jW2+9pc/5hRdeMIWFhZk2bdrk0fMbMmSILr/1+3nmzBmbfbz5/K6++mrT5MmT9euuX7/edO2115rq1q1rSk9PL/PPpCd+j105vyuuuEK/lvV7qN4TXzg/5ZdffjH99ttvpp07d5p27Nhheu655/RnQ52zr79/rpyfr79/1v7++29TQkKCqXXr1qYnnnjCcr8vvocElQI6depkGj58uOV2Xl6eqVatWqbExESTNwYV9aVlT3Jysv4F/P777y33bdu2TX9BLl++XN9WH8Dg4GDTsWPHLPtMmjTJFBMTY8rKytK3n3nmGR2GrN166636P21PKvhFnp+fb6pRo4bp7bfftjnH8PBw/WWsqF8Y9bxVq1ZZ9pk9e7YpKCjIdPjwYX37o48+MlWsWNFyfsqzzz5ratq0qeX2LbfcYurfv79NeTp37mx6+OGHPXZ+5qAycOBAh8/xpfNTTpw4ocu7ePHiMv9MlsXvccHzM3/RWX8pFORL52emPk+fffaZ371/Bc/Pn96/tLQ0U+PGjU3z5s2zOSdffQ9p+rGSnZ0ta9as0U0K1usJqdvLly8Xb6SaPlRTQoMGDXSTgKqyU9R55OTk2JyLquqvW7eu5VzUpar2r169umWfq6++Wi9EtWXLFss+1scw71PWP499+/bJsWPHbMqi1olQ1YnW56OaQzp06GDZR+2v3sOVK1da9unRo4eUK1fO5nxUFf7Zs2cNP2dVnaqqWps2bSrDhg2T06dPWx7ztfNLSUnRl5UqVSrTz2RZ/R4XPD+zr7/+WqpUqSItW7aUMWPGSEZGhuUxXzq/vLw8mTp1qpw7d043kfjb+1fw/Pzp/Rs+fLhuuilYDl99D316UUJ3O3XqlP7wWr9Birq9fft28TbqS1q1C6ovtaNHj8orr7yi+yZs3rxZf6mrLyv1xVbwXNRjirq0d67mx4raR31oz58/r/uKlAVzeeyVxbqs6kveWmhoqP4isd6nfv36hY5hfqxixYoOz9l8DE9R/VFuvPFGXb49e/bIc889J/369dO/2CEhIT51fmplc9Uu3q1bN/0fvvn1y+IzqQKZp3+P7Z2fcscdd0i9evX0Hw+qr9Czzz6rQ+JPP/3kM+e3adMm/cWt+jKoPgzTp0+X5s2by/r16/3i/XN0fv7y/k2dOlXWrl0rq1atKvSYr/4OElR8mPoSM1MdxFRwUb9k06ZNK7MAAfe57bbbLNfVXzTqPW3YsKGuZendu7f4EvUXnQrMS5cuFX/k6Pweeughm/dQdfxW750Knuq99AXqDx8VSlSN0Q8//CBDhgyRxYsXi79wdH4qrPj6+5eUlCRPPPGEzJs3T3dg9Rc0/VhR1X3qL9eCPaDV7Ro1aoi3Uym5SZMmsnv3bl1eVf2WnJzs8FzUpb1zNT9W1D6qx3xZhiFzeYp6b9TliRMnbB5XPdXVSBl3nHNZfwZUc576TKr305fO79FHH5WZM2fKwoULpU6dOpb7y+oz6enfY0fnZ4/640Gxfg+9/fzUX9xqFMell16qRzq1adNG3nvvPb95/xydnz+8f2vWrNH/R6jROKq2VW0qhL3//vv6uqrR8MX3kKBS4AOsPrwLFiywqeJVt63bML2VGqaqkr/6K0CdR1hYmM25qCpM1YfFfC7qUlWDWn/5qSSu
PmzmqlC1j/UxzPuU9c9DNWeoD7h1WVQ1o+qbYX0+6hdQ/bKa/fHHH/o9NP+Ho/ZRw4RVO631+ai/slSziDed86FDh3QfFfV++sL5qT7C6ktcVaWrchVsgiqrz6Snfo+dnZ896i93xfo99Nbzc0QdOysry+ffP2fn5w/vX+/evXX5VLnNm+rTpvovmq/75HtY7O63fk4NqVIjSaZMmaJHWTz00EN6SJV1D2hv8dRTT5kWLVpk2rdvnx5yqoaTqWFkajSCeRiaGj75xx9/6GFoXbp00VvBYWh9+/bVwy3V0LKqVavaHYY2atQo3Tt84sSJHhuerHqqq+FwalMfzfHjx+vrBw4csAxPVu/Fzz//bNq4caMeIWNveHK7du1MK1euNC1dulT3fLcevqt6vavhu3fddZcekqjeb3V+BYfvhoaGmt555x19zmp0lTuG7xZ1fuqxp59+Wve8V+/n/PnzTe3bt9flz8zM9InzGzZsmB4+rj6T1sM7MzIyLPuU1WfSE7/Hzs5v9+7dpldffVWfl3oP1ee0QYMGph49evjE+SmjR4/Wo5hU+dXvmLqtRpX9/vvvPv/+OTs/f3j/7Ck4kskX30OCih1qTLh6I9UYcDXESs1Z4Y3UcLCaNWvqctauXVvfVr9sZuoL/JFHHtHD79SH6oYbbtD/sVrbv3+/qV+/fnquDRVyVPjJycmx2WfhwoWmtm3b6tdRv7hqLglPUK+jvsALbmrYrnmI8osvvqi/iNUvQO/evfVcCNZOnz6tv7grVKigh9Pde++9OgRYU3OwdO/eXR9D/dxUACpo2rRppiZNmuhzVsPw1NwLnjw/9WWn/mNQ/yGo0FCvXj0970DBX2pvPj9756Y2689LWX4m3f177Oz8Dh48qL/UKlWqpH/2ao4b9R+59Twc3nx+yn333ac/e+qY6rOofsfMIcXX3z9n5+cP758rQcUX38Mg9U/x62EAAAA8jz4qAADAaxFUAACA1yKoAAAAr0VQAQAAXougAgAAvBZBBQAAeC2CCgAA8FoEFQAA4LUIKgBKJSEhQSZMmODy/mo16KCgoEILowGAPcxMCwSYnj17Stu2bYsVLopy8uRJKV++vERFRbm0v1q9Va34rFZyVYHFCCos9erVS86ePatXHQfgvUKNLgAA76P+fsnLy9NLwztTtWrVYh1braxa2uXsAQQOmn6AAHLPPffI4sWL5b333tO1GWrbv3+/pTlm9uzZenn28PBwWbp0qezZs0cGDhyoaz8qVKggHTt2lPnz5xfZ9KOO89lnn8kNN9yga1kaN24sv/zyi8OmnylTpuhajblz58oll1yiX+eaa66Ro0ePWp6Tm5srjz/+uN6vcuXK8uyzz8qQIUPk+uuvd3iuBw4ckAEDBkjFihV1jU+LFi1k1qxZ+nxVbYqiHlNlUT8X81L0iYmJUr9+fYmMjJQ2bdrIDz/8UKjsv/32m7Ru3VoiIiLksssuk82bN7vl/QFQGEEFCCAqoHTp0kUefPBBHQTUFh8fb3l89OjRMm7cONm2bZv+Ik5PT5drr71WFixYIOvWrdMBQn35Hzx4sMjXeeWVV+SWW26RjRs36ucPHjxYN/c4kpGRIe+88458+eWXsmTJEn38p59+2vL4m2++KV9//bVMnjxZ/vrrL0lNTZUZM2YUWYbhw4dLVlaWPt6mTZv0MVQIUuf7448/6n127Nihfwbq56KokPK///1PPv74Y9myZYs8+eSTcuedd+pwZ23UqFHy73//W1atWqVrlNTPJCcnx8lPH0CJlGjNZQA+q+Cy7+Yl29V/BzNmzHD6/BYtWujl283q1atnevfddy231XFeeOEFy+309HR93+zZs21e6+zZs/q2Wh5e3d69e7flORMnTjRVr17dcltdf/vtty23c3Nz9fLxAwcOdFjOVq1amcaOHWv3sYJlUDIzM/Wy98uWLbPZ9/777zfdfvvtNs+bOnWq5fHTp0+bIiMjTd99910RPzUAJUUfFQAW
HTp0sLmtalTGjh2rmzpUzYNqgjl//rzTGhVVG2Omml1iYmLkxIkTDvdXTUQNGza03K5Zs6Zl/5SUFDl+/Lh06tTJ8nhISIhuolJNNY6opqJhw4bJ77//Ln369JFBgwbZlKug3bt365qdq666qlDn33bt2tncp2qlzCpVqiRNmzbVtVAA3I+gAsAmVFhTzS/z5s3TzTKNGjXS/TZuuukm/eVdlLCwMJvbql9HUaHC3v6lHZD4wAMPyNVXX61DlgorqllHNdc89thjdvdXoUxR+9euXdvmMdVnB4Ax6KMCBBg16kaN6HGF6g+iOpqqjrGtWrXSo3VUZ9SyFBsbqzvzqv4gZqr8a9eudfpc1R9l6NCh8tNPP8lTTz0ln376qeVnYD6OWfPmzXUgUbVFKpRZb9b9eJQVK1ZYrqshzjt37tQdgQG4HzUqQIBRo3RWrlypA4fqXKqaLhxRI3bUl7zqLKpqOV588cUia0Y8RdWCqBoRFRqaNWsmH3zwgQ4IRc3DMmLECOnXr580adJE77tw4UJLmKhXr55+7syZM3VnX1VTFB0drWuQVAdadY7du3fXzU4qrKmmKzXKyOzVV1/Vo49UgHr++eelSpUqRY5AAlBy1KgAAUZ9Gas+HqoGQY1YKaq/yfjx4/UQ3q5du+qwoppS2rdvL2VNDUe+/fbb5e6779b9Q1TAUmVRw4MdUbUlauSPCidqtJIKLB999JF+TDXtqJFJapSTChuPPvqovv+1117TYUyFIvPzVFOQGq5sTY2MeuKJJ3Q/mWPHjsmvv/5qqaUB4F7MTAvA56gaDxUk1BBoFS7KCjPaAmWPph8AXk9N3qY6xF5xxRV6bpQPP/xQ9u3bJ3fccYfRRQPgYTT9APB6wcHBegZbNTNut27d9ARuaoZcOrAC/o+mHwAA4LWoUQEAAF6LoAIAALwWQQUAAHgtggoAAPBaBBUAAOC1CCoAAMBrEVQAAIDXIqgAAADxVv8Pjs4n0zRZjxcAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# 训练序列到序列模型\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    \"\"\"Train an encoder-decoder translation model with teacher forcing.\n",
    "\n",
    "    Args:\n",
    "        train_data: sequence of (input_ids, target_ids) pairs. A local\n",
    "            copy is reshuffled every epoch; the caller's list keeps its\n",
    "            order.\n",
    "        encoder: module called as encoder(input_tensor) that returns\n",
    "            (outputs, hidden).\n",
    "        decoder: module called as decoder(outputs, hidden, target_tensor)\n",
    "            whose first output is fed to nn.NLLLoss, i.e. it is expected\n",
    "            to be log-probabilities over the target vocabulary.\n",
    "        epochs: number of passes over train_data.\n",
    "        learning_rate: Adam learning rate used for both modules.\n",
    "\n",
    "    Updates the parameters of encoder and decoder in place and finally\n",
    "    plots the smoothed training loss with matplotlib.\n",
    "    \"\"\"\n",
    "    # Prepare the optimizers and the loss function\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "    encoder.zero_grad()\n",
    "    decoder.zero_grad()\n",
    "\n",
    "    # Shuffle a private copy so the caller's data is not reordered\n",
    "    train_data = list(train_data)\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # The number of passes comes from the `epochs` argument, not a global\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for input_ids, target_ids in train_data:\n",
    "                # Turn the source/target ids into 1 * seq_len tensors.\n",
    "                # Batch size 1 keeps the code simple; with larger batches\n",
    "                # the encoder would need padding and would have to return\n",
    "                # the hidden state of the last non-padding token, and the\n",
    "                # decoder would need matching changes.\n",
    "                input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "                target_tensor = torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Pass the target sequence for teacher-forcing training\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                loss_value = loss.item()\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss_value:.4f}')\n",
    "                step_losses.append(loss_value)\n",
    "                # With batch size 1 the per-step loss is very noisy;\n",
    "                # averaging the last 32 steps gives a smoother curve\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "# Hyper-parameters (n_epochs and learning_rate are passed to training)\n",
    "hidden_size, n_epochs, learning_rate = 128, 20, 1e-3\n",
    "\n",
    "# Encoder over the source vocabulary, decoder over the target vocabulary\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder,\n",
    "    epochs=n_epochs, learning_rate=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索解码，即在每一步都选取预测概率最大的词元。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 摩 登 创 客 ： 与 智 能 手 机 和 平 板 电 脑 共 舞\n",
      "target： morton innovator: dancing with smartphones and peaceboard computers\n",
      "pred： morton innovator: dancing with smartphones and peaceboard computers\n",
      "\n",
      "input： p h o t o s h o p 电 商 产 品 精 修 实 战\n",
      "target： photoshop powermaker's prosthetic warfare\n",
      "pred： photoshop powermaker's prosthetic warfare\n",
      "\n",
      "input： 从 零 开 始 儿 童 口 风 琴 图 解 教 程\n",
      "target： from scratch , the children's accordion pedagogy .\n",
      "pred： from scratch , the children's accordion pedagogy .\n",
      "\n",
      "input： 无 师 自 通 7 ： 铅 笔 素 描 五 官 超 精 解 析 （ 修 订 版 ）\n",
      "target： untied 7: pen-showing five officers super precision (revised)\n",
      "pred： untied 7: pen-showing five officers super precision (revised)\n",
      "\n",
      "input： 移 动 计 算 及 应 用 开 发 技 术\n",
      "target： mobile computing and application development techniques\n",
      "pred： and applied development and application development techniques\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    \"\"\"Greedily decode one source sentence.\n",
    "\n",
    "    The encoder output is handed to the decoder without a target\n",
    "    sequence; the highest-scoring token is taken at every decoder\n",
    "    position and decoding stops at the first EOS_token.\n",
    "\n",
    "    Returns:\n",
    "        (sentence built by output_lang.ids2sent, the decoder's third\n",
    "        output, presumably the attention weights)\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence into a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, _, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "\n",
    "        # Most likely token at every step. reshape(-1) always gives a\n",
    "        # 1-D sequence, whereas squeeze() turns a single-step output into\n",
    "        # a 0-d tensor that cannot be iterated.\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        step_ids = topi.reshape(-1).tolist()\n",
    "\n",
    "        decoded_ids = []\n",
    "        for token_id in step_ids:\n",
    "            if token_id == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(token_id)\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "# Switch both modules to evaluation mode, then greedily translate\n",
    "# 5 random pairs and print each prediction next to its reference\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(\n",
    "        encoder, decoder, pair[0], input_lang, output_lang\n",
    "    )\n",
    "    # The doubled line ending also emits the blank separator line\n",
    "    print('pred：', output_sentence, end='\\n\\n')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "\n",
    "接下来使用束搜索解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 国 之 重 器 出 版 工 程 欠 驱 动 机 器 人 动 力 学 与 控 制\n",
      "target： we're not driving robotic dynamics and controls .\n",
      "pred： we're not driving robotic dynamics and controls .\n",
      "\n",
      "input： 世 界 绘 画 经 典 教 程 — — 跟 巴 伯 学 素 描 （ 第 2 版 ）\n",
      "target： world painting classic curriculum - with barber's psychiatry . 2nd ed .\n",
      "pred： world painting classic curriculum (version 2) . 2nd ed . ed .)\n",
      "\n",
      "input： 新 媒 体 短 视 频 全 攻 略 前 期 拍 摄 后 期 处 理 广 告 变 现 营 销 推 广\n",
      "target： new media short video full-time , pre-shooting , post-processing , commercial realization , marketing .\n",
      "pred： video production self-learning manual , commercial realization , commercial realization , commercial realization , marketing .\n",
      "\n",
      "input： 拍 出 美 丽 的 风 景 旅 行 摄 影 入 门 与 提 高\n",
      "target： a beautiful view , an introduction to travel photography , and an improvement .\n",
      "pred： a beautiful view , an introduction to travel photography , and an improvement .\n",
      "\n",
      "input： 机 遇 之 门 以 色 列 闪 存 盘 之 父 的 创 业 心 路\n",
      "target： the door of opportunity , the entrepreneurship of the father of israel .\n",
      "pred： entrepreneurship for the door of opportunity in the digital economy of the father of israel .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container class that manages all candidate hypotheses\n",
    "class BeamHypotheses:\n",
    "    \"\"\"Holds at most num_beams scored decoding hypotheses.\n",
    "\n",
    "    Each entry is a tuple (score, hyp, hidden): score is the summed\n",
    "    log-probability divided by max(len(hyp), 1), hyp is the list of token\n",
    "    ids decoded so far, and hidden is the decoder state to continue from.\n",
    "    worst_score is the lowest score currently held (1e9 while empty).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        self.max_length = max_length\n",
    "        self.num_beams = num_beams\n",
    "        self.beams = []\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "\n",
    "    def _is_open(self, hyp):\n",
    "        # Still expandable: shorter than max_length and not ended by <eos>\n",
    "        return len(hyp) < self.max_length and \\\n",
    "            (len(hyp) == 0 or hyp[-1] != EOS_token)\n",
    "\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        \"\"\"Insert a hypothesis when there is room or it beats worst_score.\n",
    "\n",
    "        If that overfills the container, the lowest-scoring entry (the\n",
    "        earliest one on ties) is evicted and worst_score is refreshed.\n",
    "        \"\"\"\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        has_room = len(self) < self.num_beams\n",
    "        if not (has_room or score > self.worst_score):\n",
    "            return\n",
    "        self.beams.append((score, hyp, hidden))\n",
    "        if len(self) <= self.num_beams:\n",
    "            self.worst_score = min(score, self.worst_score)\n",
    "            return\n",
    "        # Over capacity: stable sort by score (ties keep insertion order),\n",
    "        # drop the lowest entry and adopt the runner-up as the new bound\n",
    "        order = sorted(range(len(self.beams)),\n",
    "            key=lambda idx: self.beams[idx][0])\n",
    "        runner_up = self.beams[order[1]][0]\n",
    "        del self.beams[order[0]]\n",
    "        self.worst_score = runner_up\n",
    "\n",
    "    def pop(self):\n",
    "        \"\"\"Remove the first unfinished hypothesis in insertion order.\n",
    "\n",
    "        Returns (True, (score, hyp, hidden)) on success and (False, None)\n",
    "        when every stored hypothesis is finished or the container is empty.\n",
    "        \"\"\"\n",
    "        for pos, entry in enumerate(self.beams):\n",
    "            if not self._is_open(entry[1]):\n",
    "                continue\n",
    "            del self.beams[pos]\n",
    "            if self.beams:\n",
    "                self.worst_score = min(s for s, _, _ in self.beams)\n",
    "            else:\n",
    "                self.worst_score = 1e9\n",
    "            return True, entry\n",
    "        return False, None\n",
    "\n",
    "    def pop_best(self):\n",
    "        \"\"\"Return the highest-scoring hypothesis WITHOUT removing it.\n",
    "\n",
    "        On ties the most recently inserted entry wins. Returns\n",
    "        (True, (score, hyp, hidden)), or (False, None) when empty.\n",
    "        \"\"\"\n",
    "        if not self.beams:\n",
    "            return False, None\n",
    "        best = max(reversed(range(len(self.beams))),\n",
    "            key=lambda idx: self.beams[idx][0])\n",
    "        return True, self.beams[best]\n",
    "\n",
    "\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    \"\"\"Beam-search decode one source sentence.\n",
    "\n",
    "    Hypotheses are kept in a BeamHypotheses container (at most num_beams\n",
    "    entries, at most MAX_LENGTH tokens each). Unfinished hypotheses are\n",
    "    popped one at a time, advanced by one decoder.forward_step, and their\n",
    "    top-num_beams continuations are added back; the loop ends when no\n",
    "    unfinished hypothesis remains.\n",
    "\n",
    "    Returns:\n",
    "        The best hypothesis as a sentence (trailing EOS_token removed), or\n",
    "        '' if the container ends up empty.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence into a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Seed the container with a single empty hypothesis\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, [], encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Take out one unfinished hypothesis at a time\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "\n",
    "            score, hyp, decoder_hidden = item\n",
    "\n",
    "            # Decoder input: the last generated token, or <sos> if none\n",
    "            last_token = hyp[-1] if len(hyp) > 0 else SOS_token\n",
    "            decoder_input = torch.full((1, 1), last_token, dtype=torch.long)\n",
    "\n",
    "            # Decode one step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            # Scores are summed as log-probabilities. log_softmax leaves\n",
    "            # values that already are log-probs unchanged and normalizes\n",
    "            # raw logits otherwise. NOTE(review): forward_step presumably\n",
    "            # returns logits, as in the PyTorch tutorial; confirm.\n",
    "            log_probs = nn.functional.log_softmax(decoder_output, dim=-1)\n",
    "\n",
    "            # Top-k continuations; reshape(-1) keeps them 1-D even for\n",
    "            # num_beams == 1, where squeeze() would yield 0-d tensors\n",
    "            topk_values, topk_ids = log_probs.topk(num_beams)\n",
    "            for logp, token_id in zip(topk_values.reshape(-1).tolist(),\n",
    "                    topk_ids.reshape(-1).tolist()):\n",
    "                # score is length-normalized; recover the running sum\n",
    "                sum_logprobs = score * len(hyp) + logp\n",
    "                hypotheses.add(sum_logprobs, hyp + [token_id],\n",
    "                    decoder_hidden)\n",
    "\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if not flag:\n",
    "            return ''\n",
    "        hyp = item[1]\n",
    "        # Strip the trailing <eos>; guard against an empty hypothesis\n",
    "        if hyp and hyp[-1] == EOS_token:\n",
    "            hyp = hyp[:-1]\n",
    "        return output_lang.ids2sent(hyp)\n",
    "\n",
    "# Evaluation mode, then beam-search translate 5 random pairs and\n",
    "# print each prediction next to its reference\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence = beam_search_decode(\n",
    "        encoder, decoder, pair[0], input_lang, output_lang\n",
    "    )\n",
    "    # The doubled line ending also emits the blank separator line\n",
    "    print('pred：', output_sentence, end='\\n\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
